import warnings

import torch
from torch import IntTensor


from ModularUtils.ControllerConstants import map_dictfill_to_discrete
from ModularUtils.ControllerModel import get_generated_labels
from ModularUtils.FunctionsConstant import asKey
from ModularUtils.Generators import sample_gumbel


def sampling(Exp, n):
    """Draw one batch of ``n`` prior noise samples for every label and confounder.

    Args:
        Exp: experiment object; reads ``label_names``, ``exogenous``,
            ``label_dim``, ``confTochild``, ``NOISE_DIM``, ``CONF_NOISE_DIM``
            and ``DEVICE``.
        n: number of samples to draw per noise variable.

    Returns:
        A tuple ``(label_noises, latent_noises, gumbel_noise)``:
        ``label_noises`` maps ``Exp.exogenous[label]`` to a standard-normal
        tensor of shape ``(n, Exp.NOISE_DIM)``; ``latent_noises`` maps each key
        of ``Exp.confTochild`` to a standard-normal tensor of shape
        ``(n, Exp.CONF_NOISE_DIM)``; ``gumbel_noise`` maps each label name to
        whatever ``sample_gumbel`` returns for the requested shape
        ``(n, Exp.label_dim[label])``.
    """
    label_noises, gumbel_noise = {}, {}
    for label in Exp.label_names:
        # For each label the Gaussian draw comes before the Gumbel draw.
        gaussian = torch.randn(n, Exp.NOISE_DIM)
        label_noises[Exp.exogenous[label]] = gaussian.to(Exp.DEVICE)
        gumbel_noise[label] = sample_gumbel(Exp, (n, Exp.label_dim[label]))

    # One Gaussian block per confounder, in the iteration order of confTochild.
    latent_noises = {
        confounder: torch.randn(n, Exp.CONF_NOISE_DIM).to(Exp.DEVICE)
        for confounder in Exp.confTochild.keys()
    }

    return label_noises, latent_noises, gumbel_noise


def check_if_working(Exp, label_generators, posterior_label, posterior_latent, gumbel_noise, evidence, num_noise):
    """Regenerate the evidence variables from accepted noise and print their value counts.

    Diagnostic helper: it only prints and returns ``None``.

    Args:
        Exp: experiment object; reads ``twin_map`` and ``cf_evidence``.
        label_generators: generators forwarded to ``get_generated_labels``.
        posterior_label: accepted label noise, forwarded unchanged.
        posterior_latent: accepted confounder noise, forwarded unchanged.
        gumbel_noise: accepted Gumbel noise, forwarded as ``gumbel_noise=``.
        evidence: NOT read. The variables to regenerate are always taken from
            ``Exp.cf_evidence`` mapped through ``Exp.twin_map``.
        num_noise: sample count passed to ``get_generated_labels``; also the
            number of leading rows that are summarised.
    """
    # The caller-supplied ``evidence`` is deliberately ignored (see docstring).
    twin_evidence = [Exp.twin_map[name] for name in Exp.cf_evidence]

    generated = get_generated_labels(
        Exp, label_generators, posterior_label, posterior_latent, {}, twin_evidence, num_noise,
        gumbel_noise=gumbel_noise,
    )
    discrete = torch.tensor(map_dictfill_to_discrete(Exp, generated, twin_evidence))

    print("1st try")
    # Distinct generated rows and how often each occurs among the first num_noise rows.
    print(torch.unique(discrete[:num_noise], return_counts=True, dim=0))


def success(Exp, label_generators, label_noises, latent_noises, gumbel_noise, evidence_list, narray, min_sample):
    """Select, per evidence setting, the sampled noise rows assigned to that setting.

    Labels for the evidence variables are generated from the supplied noise,
    the samples are grouped by their distinct generated label row, and for
    evidence entry ``i`` up to ``narray[i]`` samples of group ``i`` are kept.

    NOTE(review): evidence entry ``i`` is paired with the ``i``-th *sorted
    unique row* of the generated labels; the evidence values themselves are
    never compared with the generated rows. That is only correct when
    ``evidence_list`` lists every label combination that occurs, in sorted
    order - confirm against the callers.

    Args:
        Exp: experiment object; reads ``twin_map`` and ``label_names``.
        label_generators: generators forwarded to ``get_generated_labels``.
        label_noises: dict of label noise tensors, each viewable as ``(n, 1, -1)``.
        latent_noises: dict of confounder noise tensors, same layout.
        gumbel_noise: dict label name -> Gumbel noise viewable as ``(n, -1)``.
        evidence_list: list of evidence dicts; the keys of the first entry
            (mapped through ``Exp.twin_map``) are the variables generated.
        narray: 1-D tensor with the number of samples still wanted per
            evidence entry.
        min_sample: lower bound for the batch size ``n`` (Python number or
            0-dim tensor).

    Returns:
        ``(res_noise, res_gum)``, both keyed by ``asKey(evidence)``.
        ``res_noise`` values have shape ``(k, num_noise_vars, noise_dim)`` and
        ``res_gum`` values ``(k, total_gumbel_dim)`` with ``k <= narray[i]``.
    """
    # Batch size: the largest outstanding request, but at least ``min_sample``.
    # ``int()`` accepts Python numbers and 0-dim tensors alike, which avoids
    # the copy-construct warning of ``torch.tensor(<tensor>)``.
    batch = max(int(torch.max(narray).item()), int(min_sample))
    n = torch.tensor(batch)  # helpers keep receiving a 0-dim tensor, as before

    # Rename every evidence variable to its twin; only the first entry's
    # variable names are needed to decide which labels to generate.
    twin_evidence = [{Exp.twin_map[key]: value for key, value in ev.items()} for ev in evidence_list]
    obs_vars = list(twin_evidence[0].keys())

    gen_labels_dict = get_generated_labels(Exp, label_generators, label_noises, latent_noises, {}, obs_vars, n,
                                           gumbel_noise=gumbel_noise)
    gen_labels = torch.tensor(map_dictfill_to_discrete(Exp, gen_labels_dict, obs_vars))

    # Stack the noises in dict order (labels first, then confounders); the
    # consumer slices them back by position, so this order must not change.
    all_noises = {**label_noises, **latent_noises}
    new_z = torch.cat([noise.view(batch, 1, -1) for noise in all_noises.values()], 1)
    new_gum = torch.cat([gumbel_noise[label].view(batch, -1) for label in Exp.label_names], 1)

    # inverse[i] is the index of sample i's row among the sorted unique rows.
    _, inverse = torch.unique(gen_labels, sorted=True, return_inverse=True, dim=0)

    res_noise = {}
    res_gum = {}
    for evno, ev in enumerate(evidence_list):
        need = int(narray[evno].item())
        # Ascending sample indices of group ``evno``; empty when that group
        # does not exist in this batch (previously a KeyError beyond group 1).
        rows = torch.nonzero(inverse == evno, as_tuple=False).flatten()[:need].tolist()
        key = asKey(ev)
        res_noise[key] = new_z[rows]
        res_gum[key] = new_gum[rows]

    return res_noise, res_gum


def reshape_noises(Exp, z_noise, z_gum):
    """Split stacked noise tensors back into per-variable dictionaries.

    Args:
        Exp: experiment object; reads ``label_names``, ``label_dim``,
            ``noise_params``, ``exogenous`` and ``confTochild``.
        z_noise: tensor indexed as ``z_noise[:, slot, :]``, where ``slot`` is
            the position of a noise name in ``Exp.noise_params``.
        z_gum: tensor whose second axis concatenates the Gumbel noise of all
            labels in ``Exp.label_names`` order, ``Exp.label_dim[label]``
            columns each.

    Returns:
        ``(posterior_label, posterior_latent, gum_noise)``: label noise keyed
        by exogenous noise name, confounder noise keyed by confounder name,
        and Gumbel noise keyed by label name.
    """
    # Walk the Gumbel columns with a running offset, one contiguous slab per label.
    gum_noise = {}
    offset = 0
    for label in Exp.label_names:
        width = Exp.label_dim[label]
        gum_noise[label] = z_gum[:, offset:offset + width]
        offset += width

    posterior_label, posterior_latent = {}, {}
    # ``slot`` counts every key of noise_params, including names that belong
    # to neither group, so positions follow noise_params order exactly.
    for slot, noise in enumerate(Exp.noise_params.keys()):
        block = z_noise[:, slot, :]
        if noise in Exp.exogenous.values():
            posterior_label[noise] = block
        elif noise in Exp.confTochild.keys():
            posterior_latent[noise] = block

    return posterior_label, posterior_latent, gum_noise


def rejection_sampling_optimized(Exp, label_generators, n, evidence_list, max_rejections=0, warn=100,
                                 min_batch=5000):
    """Collect ``n`` accepted noise samples for every evidence setting by rejection sampling.

    Each round draws a fresh batch of prior noise with ``sampling``, lets
    ``success`` pick the rows assigned to each evidence setting, and subtracts
    the accepted counts from the outstanding demand. Rounds repeat until no
    evidence setting needs more samples. Runs under ``torch.no_grad()``.

    Args:
        Exp: experiment object forwarded to the helpers.
        label_generators: generators forwarded to the helpers.
        n: number of samples wanted per evidence setting.
        evidence_list: list of evidence dicts; results are keyed by
            ``asKey(evidence)``.
        max_rejections: maximum number of sampling rounds; ``0`` (default)
            means unlimited.
        warn: emit a warning once when this many rounds have been completed;
            a falsy value disables the warning.
        min_batch: smallest number of samples drawn per round.

    Returns:
        ``(posterior_label, posterior_latent, gumbel_noise)``: three dicts
        keyed by ``asKey(evidence)``, each value being the corresponding
        output of ``reshape_noises`` for that evidence setting.

    Raises:
        RuntimeError: if ``max_rejections`` is non-zero and that many rounds
            were completed while some evidence setting still lacks samples.
    """
    with torch.no_grad():
        rounds = 0

        # Per-evidence lists of accepted batches, concatenated after the loop.
        z_noise = {}
        z_gum = {}
        for evid in evidence_list:
            z_noise[asKey(evid)] = []
            z_gum[asKey(evid)] = []

        # Outstanding demand per evidence setting.
        narray = torch.ones(len(evidence_list)) * n

        while torch.count_nonzero(narray) > 0:
            # Honour the round limit; it used to be accepted but ignored, so an
            # evidence setting that is never generated made this loop endless.
            if max_rejections and rounds >= max_rejections:
                raise RuntimeError(
                    f'Rejection sampling stopped after {max_rejections} rounds; '
                    f'samples still missing per evidence: {narray.tolist()}'
                )
            if warn and rounds == warn:
                warnings.warn(f'Exceeded max rejections: {warn}')

            # Draw at least ``min_batch`` samples so rare evidence still gets hits.
            largest_need = int(torch.max(narray).item())
            batch = torch.tensor(max(largest_need, min_batch))
            label_noises, latent_noises, gumbel_noise = sampling(Exp, batch)

            res_noise, res_gum = success(Exp, label_generators, label_noises, latent_noises, gumbel_noise,
                                         evidence_list, narray, batch)

            accepted_total = 0
            # z_noise preserves insertion order, so idx lines up with narray.
            for idx, kev in enumerate(z_noise):
                accepted = res_noise[kev].shape[0]
                z_noise[kev].append(res_noise[kev])
                z_gum[kev].append(res_gum[kev])
                narray[idx] -= accepted
                accepted_total += accepted

            print("Reserved:", accepted_total, ", narray", narray)
            rounds += 1

        for kev in z_noise:
            z_noise[kev] = torch.cat(z_noise[kev], 0)
            z_gum[kev] = torch.cat(z_gum[kev], 0)

        posterior_label, posterior_latent, gumbel_noise = {}, {}, {}
        for kev in z_noise:
            posterior_label[kev], posterior_latent[kev], gumbel_noise[kev] = reshape_noises(
                Exp, z_noise[kev], z_gum[kev])

        # Diagnostic printout: regenerate the evidence variables from the accepted noise.
        for idx, kev in enumerate(posterior_label):
            check_if_working(Exp, label_generators, posterior_label[kev], posterior_latent[kev],
                             gumbel_noise[kev], evidence_list[idx], z_noise[kev].shape[0])

    return posterior_label, posterior_latent, gumbel_noise

